# -*- coding: utf-8 -*-
"""
   File Name:  models.py
   Author :    liccoo
   Time:       2022/8/24 12:44
"""
from torch import nn


class ForwardModel(nn.Module):
    """Fully connected ReLU MLP from ``num_labels`` to ``num_labels`` features.

    Architecture (the defaults reproduce the original fixed design)::

        Linear(num_labels, hidden_dim) -> ReLU
        -> [Linear(hidden_dim, hidden_dim) -> ReLU] * num_hidden_layers
        -> Linear(hidden_dim, num_labels)

    All layers live in one ``nn.Sequential`` named ``fc``. Module indices
    match the original hand-written stack, so with the defaults the
    ``state_dict`` keys (``fc.0.weight`` ... ``fc.16.bias``) are unchanged
    and previously saved checkpoints still load.

    Args:
        num_labels: Size of the input and output feature dimension.
        hidden_dim: Width of every hidden layer (keyword-only, default 100).
        num_hidden_layers: Number of hidden-to-hidden ``Linear`` + ``ReLU``
            pairs (keyword-only, default 7).

    Raises:
        ValueError: If ``hidden_dim`` < 1 or ``num_hidden_layers`` < 0.
    """

    def __init__(self, num_labels=1, *, hidden_dim=100, num_hidden_layers=7):
        super(ForwardModel, self).__init__()
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
        if num_hidden_layers < 0:
            # range() would silently treat a negative depth as zero.
            raise ValueError(
                f"num_hidden_layers must be >= 0, got {num_hidden_layers}"
            )
        layers = [nn.Linear(num_labels, hidden_dim), nn.ReLU()]
        for _ in range(num_hidden_layers):
            # Fresh module instances each iteration (no shared weights).
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        # No activation after the final layer, so outputs are not clamped.
        layers.append(nn.Linear(hidden_dim, num_labels))
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``.

        ``nn.Linear`` acts on the last dimension, so ``x`` must have last
        dimension ``num_labels``; any leading (e.g. batch) dimensions are
        preserved. Returns a tensor whose last dimension is ``num_labels``.
        """
        return self.fc(x)


class InverseModel(nn.Module):
    """Fully connected ReLU MLP from ``num_labels`` to ``num_labels`` features.

    Architecturally identical to ``ForwardModel``; only the class name and
    the ``forward`` argument name differ. Layout (defaults reproduce the
    original fixed design)::

        Linear(num_labels, hidden_dim) -> ReLU
        -> [Linear(hidden_dim, hidden_dim) -> ReLU] * num_hidden_layers
        -> Linear(hidden_dim, num_labels)

    All layers live in one ``nn.Sequential`` named ``fc``. Module indices
    match the original hand-written stack, so with the defaults the
    ``state_dict`` keys (``fc.0.weight`` ... ``fc.16.bias``) are unchanged
    and previously saved checkpoints still load.

    Args:
        num_labels: Size of the input and output feature dimension.
        hidden_dim: Width of every hidden layer (keyword-only, default 100).
        num_hidden_layers: Number of hidden-to-hidden ``Linear`` + ``ReLU``
            pairs (keyword-only, default 7).

    Raises:
        ValueError: If ``hidden_dim`` < 1 or ``num_hidden_layers`` < 0.
    """

    def __init__(self, num_labels=1, *, hidden_dim=100, num_hidden_layers=7):
        super(InverseModel, self).__init__()
        if hidden_dim < 1:
            raise ValueError(f"hidden_dim must be >= 1, got {hidden_dim}")
        if num_hidden_layers < 0:
            # range() would silently treat a negative depth as zero.
            raise ValueError(
                f"num_hidden_layers must be >= 0, got {num_hidden_layers}"
            )
        layers = [nn.Linear(num_labels, hidden_dim), nn.ReLU()]
        for _ in range(num_hidden_layers):
            # Fresh module instances each iteration (no shared weights).
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        # No activation after the final layer, so outputs are not clamped.
        layers.append(nn.Linear(hidden_dim, num_labels))
        self.fc = nn.Sequential(*layers)

    def forward(self, y):
        """Apply the network to ``y``.

        ``nn.Linear`` acts on the last dimension, so ``y`` must have last
        dimension ``num_labels``; any leading (e.g. batch) dimensions are
        preserved. Returns a tensor whose last dimension is ``num_labels``.
        """
        return self.fc(y)
